import json
import time

from fastapi import WebSocket, WebSocketDisconnect
from typing import Union
import asyncio
import aioredis
from app.conf import settings

# Connection settings for the module's Redis client, read from app settings
# (database index is fixed at 0).
redis_conf = dict(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    password=settings.REDIS_PASSWORD,
    db=0,
)


# Shared async Redis client built from ``redis_conf``. ``decode_responses=True``
# makes replies come back as ``str`` (utf-8) rather than ``bytes``.
redis_client = aioredis.from_url(
    "redis://{}".format(redis_conf['host']),
    db=redis_conf['db'],
    password=redis_conf['password'],
    port=redis_conf['port'],
    encoding="utf-8",
    decode_responses=True,
)


async def publish_message(topic: str, message: Union[str, dict], message_type: str = "default"):
    """Publish *message* on the Redis pub/sub channel *topic*.

    The published payload is a JSON string of the form
    ``{"message": message, "message_type": message_type, "message_id": str(time.time())}``,
    the shape that ``reader`` decodes on the subscriber side.

    Args:
        topic: Redis channel name to publish on.
        message: Payload to deliver; must be JSON-serializable.
        message_type: Free-form type tag forwarded to subscribers.

    Raises:
        TypeError: if *message* cannot be JSON-encoded.
    """
    send_msg = json.dumps({
        "message": message,
        "message_type": message_type,
        "message_id": str(time.time()),
    })
    # Reuse the shared module-level client. The previous version built a new
    # client/connection pool on every call and never closed it explicitly.
    await redis_client.publish(topic, send_msg)


class ConnectionManager:
    """Registry of open WebSocket connections, grouped by client id.

    A client id may have several sockets open at once; messages addressed to
    that client are sent to all of them. While a client has at least one
    registered socket, a single background task runs
    ``register_pubsub(client_id)`` on its behalf.
    """

    def __init__(self):
        # client_id -> list of currently registered WebSocket objects
        self.websocket_connections = {}
        # client_id -> asyncio.Task running register_pubsub(client_id).
        # Holding the reference keeps the task alive (the event loop only keeps
        # weak references to tasks) and lets us cancel it once the client's
        # last socket is gone.
        self.pubsub_tasks = {}

    def _add_connection(self, websocket: WebSocket, client_id):
        """Register *websocket* under *client_id* and make sure its pub/sub task is running."""
        self.websocket_connections.setdefault(client_id, []).append(websocket)
        task = self.pubsub_tasks.get(client_id)
        # One subscription per client id: start it for the first socket, or
        # restart it if a previous task has already finished.
        if task is None or task.done():
            self.pubsub_tasks[client_id] = asyncio.create_task(register_pubsub(client_id))

    def _remove_connection(self, websocket: WebSocket, client_id):
        """Unregister *websocket*; stop the client's pub/sub task when no sockets remain.

        Safe to call more than once for the same socket or an unknown client id.
        """
        remaining = [
            conn for conn in self.websocket_connections.get(client_id, ())
            if conn is not websocket
        ]
        if remaining:
            self.websocket_connections[client_id] = remaining
            return
        self.websocket_connections.pop(client_id, None)
        task = self.pubsub_tasks.pop(client_id, None)
        if task is not None and not task.done():
            task.cancel()

    async def connect(self, websocket: WebSocket, client_id):
        """Accept *websocket*, register it under *client_id* and serve it until it closes.

        Each received text frame is passed to ``handle_websocket_message``,
        which sends it to every socket registered for the same *client_id*
        (including this one). The socket is unregistered however the receive
        loop ends. ``WebSocketDisconnect`` is treated as a normal end; any
        other exception propagates after cleanup.
        """
        await websocket.accept()
        self._add_connection(websocket, client_id)
        try:
            while True:
                message = await websocket.receive_text()
                await self.handle_websocket_message(message, client_id)
        except WebSocketDisconnect:
            # Client went away: a normal way for this loop to end.
            pass
        finally:
            self._remove_connection(websocket, client_id)

    async def handle_websocket_message(self, message: Union[dict, str], client_id: str, message_type: str = "default"):
        """Send *message* to every socket registered under *client_id*.

        Each socket receives the JSON object
        ``{"type": message_type, "sender": client_id, "msg": message}``.
        Does nothing if the client has no registered sockets.
        """
        # Iterate over a snapshot: the awaits below let other tasks add or
        # remove sockets for this client while we are still sending.
        for conn in list(self.websocket_connections.get(client_id, ())):
            await conn.send_json({"type": message_type,
                                  "sender": client_id,
                                  "msg": message,
                                  })

    async def broadcast(self, message: Union[dict, str]):
        """Send *message* to every registered socket of every client.

        A ``str`` is sent as a text frame; any other value is sent JSON-encoded.
        """
        for connections in list(self.websocket_connections.values()):
            for conn in list(connections):
                if isinstance(message, str):
                    await conn.send_text(message)
                else:
                    await conn.send_json(message)

    async def close(self, websocket: WebSocket, client_id):
        """Close *websocket* from the server side and unregister it from *client_id*."""
        try:
            await websocket.close()
        finally:
            # Unregister even if close() fails, so dead sockets do not linger.
            self._remove_connection(websocket, client_id)

    async def disconnect(self, user_id):
        """Close and unregister every socket registered under *user_id*.

        Does nothing if *user_id* has no registered sockets.
        """
        for websocket in list(self.websocket_connections.get(user_id, ())):
            await self.close(websocket, user_id)


# ``message_id`` of the most recent pub/sub message forwarded by ``reader``;
# ``reader`` compares against it to drop repeats.
last_message = ''

# Module-level connection manager; ``reader`` forwards Redis messages through it.
websocket_manager = ConnectionManager()


async def reader(channel, client_id):
    """Forward messages from a subscribed pub/sub *channel* to *client_id*'s sockets.

    Payloads are expected to be the JSON produced by ``publish_message``:
    ``{"message": ..., "message_type": ..., "message_id": ...}``.

    Skips the following without ending the loop:
    - non-string or empty payloads;
    - payloads that are not a JSON object;
    - a message whose ``message_id`` equals the previously forwarded one.

    Runs until ``channel.listen()`` stops yielding.
    """
    global last_message
    async for msg in channel.listen():
        msg_data = msg.get("data")
        # Only non-empty string payloads are treated as published messages.
        if not msg_data or not isinstance(msg_data, str):
            continue
        try:
            msg_data_dict = json.loads(msg_data)
        except json.JSONDecodeError:
            # Not one of our payloads; skip it instead of letting the exception
            # kill this subscription task.
            continue
        if not isinstance(msg_data_dict, dict):
            continue
        message_id = msg_data_dict.get('message_id')
        if last_message == message_id:
            # Repeat of the last forwarded message: skip it but keep listening.
            # (Previously this returned, ending the subscription for good.)
            continue
        last_message = message_id
        message = msg_data_dict.get('message', {})
        message_type = msg_data_dict.get('message_type', 'default')
        await websocket_manager.handle_websocket_message(message, client_id, message_type)


async def register_pubsub(client_id):
    """Subscribe to the Redis channel named *client_id* and relay it via ``reader``.

    Uses a pub/sub object from the shared ``redis_client``. Unsubscribes once
    ``reader`` returns.
    """
    async with redis_client.pubsub() as subscription:
        await subscription.subscribe(client_id)
        await reader(subscription, client_id)
        await subscription.unsubscribe(client_id)


